{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c235d92c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Move the working directory one level up from where the kernel started, so\n",
    "# the relative paths used later (e.g. ./example_data) resolve from there.\n",
    "# The starting directory is remembered on the first run: a plain\n",
    "# os.chdir('../') would climb one more level every time this cell is re-run.\n",
    "if '_NOTEBOOK_START_DIR' not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_START_DIR, '..'))\n",
    "\n",
    "from medclip import MedCLIPModel, MedCLIPVisionModelViT\n",
    "from medclip import MedCLIPProcessor\n",
    "\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "88263d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the MedCLIP processor, then run the demo image and the two text\n",
    "# prompts through it to obtain `inputs`.\n",
    "processor = MedCLIPProcessor()\n",
    "\n",
    "image = Image.open('./example_data/view1_frontal.jpg')\n",
    "inputs = processor(\n",
    "    text=[\n",
    "        \"lungs remain severely hyperinflated with upper lobe emphysema\",\n",
    "        \"opacity left costophrenic angle is new since prior exam ___ represent some loculated fluid cavitation unlikely\",\n",
    "    ],\n",
    "    images=image,\n",
    "    return_tensors=\"pt\",\n",
    "    padding=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "18190d5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zifengw2/miniconda3/envs/medclip/lib/python3.8/site-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2228.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n",
      "Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.weight', 'classifier.bias']\n",
      "- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load model weight from: ./data/MedCLIP/checkpoints/vision_text_pretrain/25000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "MedCLIPModel(\n",
       "  (vision_model): MedCLIPVisionModelViT(\n",
       "    (model): SwinModel(\n",
       "      (embeddings): SwinEmbeddings(\n",
       "        (patch_embeddings): SwinPatchEmbeddings(\n",
       "          (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))\n",
       "        )\n",
       "        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "      (encoder): SwinEncoder(\n",
       "        (layers): ModuleList(\n",
       "          (0): SwinStage(\n",
       "            (blocks): ModuleList(\n",
       "              (0): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=96, out_features=96, bias=True)\n",
       "                    (key): Linear(in_features=96, out_features=96, bias=True)\n",
       "                    (value): Linear(in_features=96, out_features=96, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=96, out_features=96, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((96,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=96, out_features=384, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=384, out_features=96, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (1): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=96, out_features=96, bias=True)\n",
       "                    (key): Linear(in_features=96, out_features=96, bias=True)\n",
       "                    (value): Linear(in_features=96, out_features=96, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=96, out_features=96, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((96,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=96, out_features=384, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=384, out_features=96, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "            )\n",
       "            (downsample): SwinPatchMerging(\n",
       "              (reduction): Linear(in_features=384, out_features=192, bias=False)\n",
       "              (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "          (1): SwinStage(\n",
       "            (blocks): ModuleList(\n",
       "              (0): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((192,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=192, out_features=192, bias=True)\n",
       "                    (key): Linear(in_features=192, out_features=192, bias=True)\n",
       "                    (value): Linear(in_features=192, out_features=192, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=192, out_features=192, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((192,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=192, out_features=768, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=768, out_features=192, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (1): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((192,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=192, out_features=192, bias=True)\n",
       "                    (key): Linear(in_features=192, out_features=192, bias=True)\n",
       "                    (value): Linear(in_features=192, out_features=192, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=192, out_features=192, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((192,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=192, out_features=768, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=768, out_features=192, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "            )\n",
       "            (downsample): SwinPatchMerging(\n",
       "              (reduction): Linear(in_features=768, out_features=384, bias=False)\n",
       "              (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "          (2): SwinStage(\n",
       "            (blocks): ModuleList(\n",
       "              (0): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (key): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (value): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=384, out_features=1536, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=1536, out_features=384, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (1): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (key): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (value): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=384, out_features=1536, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=1536, out_features=384, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (2): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (key): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (value): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=384, out_features=1536, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=1536, out_features=384, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (3): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (key): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (value): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=384, out_features=1536, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=1536, out_features=384, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (4): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (key): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (value): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=384, out_features=1536, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=1536, out_features=384, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (5): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (key): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (value): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=384, out_features=384, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((384,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=384, out_features=1536, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=1536, out_features=384, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "            )\n",
       "            (downsample): SwinPatchMerging(\n",
       "              (reduction): Linear(in_features=1536, out_features=768, bias=False)\n",
       "              (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "          (3): SwinStage(\n",
       "            (blocks): ModuleList(\n",
       "              (0): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "              (1): SwinLayer(\n",
       "                (layernorm_before): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "                (attention): SwinAttention(\n",
       "                  (self): SwinSelfAttention(\n",
       "                    (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                  (output): SwinSelfOutput(\n",
       "                    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                    (dropout): Dropout(p=0.0, inplace=False)\n",
       "                  )\n",
       "                )\n",
       "                (drop_path): SwinDropPath()\n",
       "                (layernorm_after): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "                (intermediate): SwinIntermediate(\n",
       "                  (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "                  (intermediate_act_fn): GELUActivation()\n",
       "                )\n",
       "                (output): SwinOutput(\n",
       "                  (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "                  (dropout): Dropout(p=0.0, inplace=False)\n",
       "                )\n",
       "              )\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      (pooler): AdaptiveAvgPool1d(output_size=1)\n",
       "    )\n",
       "    (projection_head): Linear(in_features=768, out_features=512, bias=False)\n",
       "  )\n",
       "  (text_model): MedCLIPTextModel(\n",
       "    (model): BertModel(\n",
       "      (embeddings): BertEmbeddings(\n",
       "        (word_embeddings): Embedding(28996, 768, padding_idx=0)\n",
       "        (position_embeddings): Embedding(512, 768)\n",
       "        (token_type_embeddings): Embedding(2, 768)\n",
       "        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (encoder): BertEncoder(\n",
       "        (layer): ModuleList(\n",
       "          (0): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (1): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (2): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (3): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (4): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (5): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (6): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (7): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (8): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (9): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (10): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (11): BertLayer(\n",
       "            (attention): BertAttention(\n",
       "              (self): BertSelfAttention(\n",
       "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): BertSelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): BertIntermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): BertOutput(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (pooler): BertPooler(\n",
       "        (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "        (activation): Tanh()\n",
       "      )\n",
       "    )\n",
       "    (projection_head): Linear(in_features=768, out_features=512, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Init the model from the pretrained checkpoint (path is relative to the repo root set via os.chdir in the first cell).\n",
    "model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT, checkpoint='./data/MedCLIP/checkpoints/vision_text_pretrain/25000')\n",
    "model.cuda()\n",
    "# This notebook only runs inference, so switch to evaluation mode explicitly: that disables the\n",
    "# Dropout layers visible in the module tree regardless of how the sub-modules were constructed.\n",
    "# eval() returns the module itself, so the cell still displays the model structure.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fcb39f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass on the processed image/text batch; the result is indexed by key in the cells below.\n",
    "# NOTE(review): autograd is active here (the logits displayed further down carry a grad_fn);\n",
    "# consider wrapping this call in torch.no_grad() for a pure inference demo.\n",
    "outputs = model(**inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de8ccc0b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['img_embeds', 'text_embeds', 'logits', 'loss_value', 'logits_per_text'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show which entries the forward pass returned.\n",
    "outputs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "aa1d87ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 2])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the logits: [1, 2] for the one image and two text prompts prepared above\n",
    "# (presumably rows = images, columns = texts; confirm against MedCLIPModel.forward).\n",
    "outputs['logits'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e0d81293",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2560, 0.6394]], device='cuda:0', grad_fn=<TBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the raw logits: two values for the single input image, matching the two text prompts\n",
    "# (presumably one score per prompt; confirm against MedCLIPModel.forward).\n",
    "outputs['logits']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e62cd698",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
